Source code for hysop.operator.base.spectral_operator
# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import math
import os
import sympy as sm
import numpy as np
from hysop.constants import (
BoundaryCondition,
BoundaryExtension,
TransformType,
MemoryOrdering,
TranspositionState,
Backend,
SpectralTransformAction,
Implementation,
)
from hysop.tools.misc import compute_nbytes
from hysop.tools.htypes import check_instance, to_tuple, first_not_None, to_set
from hysop.tools.decorators import debug
from hysop.tools.units import bytes2str
from hysop.tools.numerics import (
is_fp,
is_complex,
complex_to_float_dtype,
float_to_complex_dtype,
determine_fp_types,
)
from hysop.tools.io_utils import IOParams
from hysop.tools.spectral_utils import (
SpectralTransformUtils as STU,
EnergyPlotter,
EnergyDumper,
)
from hysop.core.arrays.array_backend import ArrayBackend
from hysop.core.arrays.array import Array
from hysop.core.memory.memory_request import MemoryRequest, OperatorMemoryRequests
from hysop.core.graph.graph import (
not_initialized as _not_initialized,
initialized as _initialized,
discretized as _discretized,
ready as _ready,
)
from hysop.core.graph.computational_node_frontend import ComputationalGraphNodeFrontend
from hysop.topology.cartesian_descriptor import CartesianTopologyDescriptors
from hysop.parameters.buffer_parameter import BufferParameter
from hysop.fields.continuous_field import Field, ScalarField, TensorField
from hysop.symbolic.array import SymbolicArray
from hysop.symbolic.spectral import (
WaveNumber,
SpectralTransform,
AppliedSpectralTransform,
)
from hysop.numerics.fft.fft import (
FFTI,
simd_alignment,
is_byte_aligned,
HysopFFTWarning,
)
class SpectralComputationalGraphNodeFrontend(ComputationalGraphNodeFrontend):
def __new__(cls, implementation, enforce_implementation=True, **kwds):
return super().__new__(cls, implementation=implementation, **kwds)
def __init__(self, implementation, enforce_implementation=True, **kwds):
impl, extra_kwds = self.get_actual_implementation(
implementation=implementation,
enforce_implementation=enforce_implementation,
**kwds,
)
for k in extra_kwds.keys():
assert k not in kwds
kwds.update(extra_kwds)
super().__init__(implementation=impl, **kwds)
@classmethod
def get_actual_implementation(
cls, implementation, enforce_implementation=True, cl_env=None, **kwds
):
"""
Parameters
----------
implementation: Implementation, optional, defaults to None
User desired target implementation.
enforce_implementation: bool, optional, defaults to True
If this is set to True, input implementation is enforced.
Else, this function may select another implementation when some conditions are met:
Case 1: Host FFT by mapping CPU OpenCL buffers
Conditions:
a/ input implementation is set to OPENCL
b/ cl_env.device is of type CPU
c/ Implementation.PYTHON is a valid operator implementation
d/ Target python operator supports OPENCL as backend
e/ OpenCL platform has zero copy capabilities (cannot be checked)
=> If cl_env is not given, this will yield a RuntimeError
=> In this case PYTHON implementation is chosen instead.
Buffers are mapped to host.
By default this should give multithreaded FFTW + multithreaded numba.
For all other cases, this parameter is ignored.
Notes
-----
clFFT (gpyFFT) support for OpenCL CPU devices is a bit neglected.
This function makes it possible to override the implementation target from
OPENCL to PYTHON when a CPU OpenCL environment is given as input.
By default, the CPU FFT target is FFTW (pyFFTW) which has much
better support (multithreaded FFTW + multithreaded numba).
OpenCL buffers are mapped to host memory with enqueue_map_buffer
(this assumes that all OpenCL buffers have been allocated
with zero-copy capability in the target OpenCL platform).
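Examples
--------
Illustrative sketch only (the operator frontend name `MySpectralOp` and the
`cl_env` object wrapping a CPU OpenCL device are placeholders):

    impl, extra_kwds = MySpectralOp.get_actual_implementation(
        implementation=Implementation.OPENCL,
        enforce_implementation=False,
        cl_env=cl_env)
    # impl is Implementation.PYTHON and extra_kwds contains
    # {'enable_opencl_host_buffer_mapping': True}.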
"""
implementation = first_not_None(implementation, cls.default_implementation())
assert implementation in cls.implementations()
extra_kwds = {}
if enforce_implementation:
return (implementation, extra_kwds)
if implementation == Implementation.OPENCL:
if cl_env is None:
msg = "enforce_implementation was set to False, "
msg += "implementation is OPENCL, but no cl_env was passed "
msg += "to check if the device is of type CPU."
raise RuntimeError(msg)
from hysop.backend.device.opencl import cl
if cl_env.device.type == cl.device_type.CPU:
if Implementation.PYTHON in cls.implementations():
from hysop.backend.host.host_operator import (
HostOperator,
OpenClMappable,
)
op_cls = cls.implementations()[Implementation.PYTHON]
if not issubclass(op_cls, HostOperator):
msg = "Operator {} is not a HostOperator."
msg = msg.format(op_cls)
raise TypeError(msg)
if not issubclass(op_cls, OpenClMappable):
msg = "Operator {} does not support host to device opencl buffer mapping."
msg = msg.format(op_cls)
raise TypeError(msg)
assert Backend.HOST in op_cls.supported_backends()
assert Backend.OPENCL in op_cls.supported_backends()
extra_kwds["enable_opencl_host_buffer_mapping"] = True
return (Implementation.PYTHON, extra_kwds)
return (implementation, extra_kwds)
class SpectralOperatorBase:
"""
Common implementation interface for spectral based operators.
"""
min_fft_alignment = simd_alignment # FFTW SIMD.
@debug
def __new__(cls, fft_interface=None, fft_interface_kwds=None, **kwds):
return super().__new__(cls, **kwds)
@debug
def __init__(self, fft_interface=None, fft_interface_kwds=None, **kwds):
"""
Initialize a spectral operator base.
kwds: dict
Base class keyword arguments.
"""
super().__init__(**kwds)
check_instance(fft_interface, FFTI, allow_none=True)
check_instance(fft_interface_kwds, dict, allow_none=True)
self.transform_groups = {} # dict[tag] -> SpectralTransformGroup
# those values will be deleted at discretization
self._fft_interface = fft_interface
self._fft_interface_kwds = fft_interface_kwds
@property
def backend(self):
msg = "FFT array backend depends on the transform group. Please use op.transform_group[key].backend instead."
raise AttributeError(msg)
@property
def FFTI(self):
msg = "FFT interface depends on the transform group. Please use op.transform_group[key].FFTI instead."
raise AttributeError(msg)
def new_transform_group(self, tag=None, mem_tag=None):
"""
Register a new SpectralTransformGroup to this spectral operator.
A SpectralTransformGroup is an object that collects forward and
backward field transforms as well as symbolic expressions and
wave number symbols.
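Examples
--------
Minimal sketch (assuming `op` is a SpectralOperatorBase instance and `F`
a ScalarField registered both as an input and an output of the operator):

    tg = op.new_transform_group()
    F_hat = tg.require_forward_transform(F)
    F_tilde = tg.require_backward_transform(F)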
"""
n = len(self.transform_groups)
tag = first_not_None(tag, f"transform_group_{n}")
msg = 'Tag "{}" has already been registered.'
assert tag not in self.transform_groups, msg.format(tag)
trg = SpectralTransformGroup(op=self, tag=tag, mem_tag=mem_tag)
self.transform_groups[tag] = trg
return trg
def pre_initialize(self, **kwds):
output_parameters = set()
for tg in self.transform_groups.values():
output_parameters.update(tg.output_parameters)
for p in output_parameters:
self.output_params.update({p})
def initialize(self, **kwds):
super().initialize(**kwds)
for tg in self.transform_groups.values():
tg.initialize(**kwds)
def get_field_requirements(self):
requirements = super().get_field_requirements()
for is_input, (field, td, req) in requirements.iter_requirements():
req.memory_order = MemoryOrdering.C_CONTIGUOUS
req.axes = (TranspositionState[field.dim].default_axes(),)
can_split = req.can_split
can_split[-1] = False
can_split[:-1] = True
req.can_split = can_split
return requirements
@debug
def get_node_requirements(self):
node_reqs = super().get_node_requirements()
node_reqs.enforce_unique_topology_shape = True
return node_reqs
def discretize(self, **kwds):
super().discretize(**kwds)
comm = self.mpi_params.comm
size = self.mpi_params.size
if size > 1:
msg = "\n[FATAL ERROR] Spectral operators do not support the MPI interface yet."
msg += "\nPlease use the Fortran FFTW interface if possible or "
msg += "use another discretization method for operator {}.\n"
msg = msg.format(self.node_tag)
print(msg)
raise NotImplementedError
for tg in self.transform_groups.values():
tg.discretize(
fft_interface=self._fft_interface,
fft_interface_kwds=self._fft_interface_kwds,
enable_opencl_host_buffer_mapping=self.enable_opencl_host_buffer_mapping,
**kwds,
)
del self._fft_interface
del self._fft_interface_kwds
def get_mem_requests(self, **kwds):
memory_requests = {}
for tg in self.transform_groups.values():
for k, v in tg.get_mem_requests(**kwds).items():
check_instance(k, str) # temporary buffer name
check_instance(v, (int, np.integer)) # nbytes
K = (k, tg.backend)
if K in memory_requests:
memory_requests[K] = max(memory_requests[K], v)
else:
memory_requests[K] = v
return memory_requests
def get_work_properties(self, **kwds):
requests = super().get_work_properties(**kwds)
for (k, backend), v in self.get_mem_requests(**kwds).items():
check_instance(k, str)
check_instance(v, (int, np.integer))
if v > 0:
mrequest = MemoryRequest(
backend=backend, size=v, alignment=self.min_fft_alignment
)
requests.push_mem_request(request_identifier=k, mem_request=mrequest)
return requests
def setup(self, work):
self.allocate_tmp_fields(work)
for tg in self.transform_groups.values():
tg.setup(work=work)
super().setup(work=work)
class SpectralTransformGroup:
"""
Build and check an FFT transform group.
This object tells the planner to build a full forward transform for all given
forward_fields. The planner will also build backward transforms for all specified
backward_fields.
The object will also automatically build per-axis wave numbers up to certain powers,
extracted from user-provided sympy expressions.
Finally, boundary condition (i.e. transform type) compatibility will be checked
using the user-provided sympy expressions.
Calling a forward transform ensures that the forward source field is read-only
and not destroyed.
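Examples
--------
Typical planning flow, shown as an illustrative sketch (assuming `op` is a
SpectralOperatorBase, `Fin` one of its input ScalarFields, `Fout` one of its
output ScalarFields, and `expr` a sympy expression relating their transforms):

    tg = op.new_transform_group()
    Fin_hat = tg.require_forward_transform(Fin)
    Fout_hat = tg.require_backward_transform(Fout)
    expr = ...  # sympy expression built from Fin_hat.s and Fout_hat.s
    wave_numbers = tg.push_expressions(expr)
    # wave_numbers contains the WaveNumber symbols extracted from expr;
    # their discrete values become available after tg.discretize().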
"""
DEBUG = False
def __new__(cls, op, tag, mem_tag, **kwds):
return super().__new__(cls, **kwds)
def __init__(self, op, tag, mem_tag, **kwds):
"""
Parameters
----------
op : SpectralOperatorBase
Operator that creates this SpectralTransformGroup.
tag: str
A tag to identify this transform group.
Each tag can only be registered once in a SpectralOperatorBase instance.
Attributes:
-----------
tag: str
SpectralTransformGroup identifier.
mem_tag: str
SpectralTransformGroup memory pool identifier.
forward_transforms: list of forward SpectralTransform
Forward fields to be planned for transform, according to Field boundary conditions.
backward_fields: list of backward SpectralTransform
Backward fields to be planned for transform, according to Field boundary conditions.
Notes
-----
All forward_fields and backward_fields have to live on the same domain and
their boundary conditions should comply with given expressions.
"""
super().__init__(**kwds)
mem_tag = first_not_None(mem_tag, "fft_pool")
check_instance(op, SpectralOperatorBase)
check_instance(tag, str)
check_instance(mem_tag, str)
self._op = op
self._tag = tag
self._mem_tag = mem_tag
self._forward_transforms = {}
self._backward_transforms = {}
self._wave_numbers = set()
self._indexed_wave_numbers = {}
self._expressions = ()
self._discrete_wave_numbers = None
def indexed_wavenumbers(self, *wave_numbers):
return tuple(self._indexed_wave_numbers[Wi] for Wi in wave_numbers)
@property
def op(self):
return self._op
@property
def tag(self):
return self._tag
@property
def mem_tag(self):
return self._mem_tag
@property
def name(self):
return self._tag
@property
def initialized(self):
return self._op.initialized
@property
def discretized(self):
return self._op.discretized
@property
def ready(self):
return self._op.ready
@property
def forward_fields(self):
return tuple(map(lambda x: x[0], self._forward_transforms.keys()))
@property
def backward_fields(self):
return tuple(map(lambda x: x[0], self._backward_transforms.keys()))
@property
def forward_transforms(self):
return self._forward_transforms
@property
def backward_transforms(self):
return self._backward_transforms
@_not_initialized
def initialize(
self, fft_granularity=None, fft_concurrent_plans=1, fft_plan_workload=1, **kwds
):
"""
Should be called after all require_forward_transform and require_backward_transform
calls.
Parameters
----------
fft_granularity: int, optional
Granularity of each directional fft plan.
1: iterate over 1d lines (slices of dimension 1)
2: iterate over 2d planes (slices of dimension 2)
3: iterate over 3d blocks (slices of dimension 3)
n-1: iterate over hyperplanes (slices of dimension n-1)
n : no iteration, the plan will handle the whole domain.
Contiguous buffers with sufficient alignment are allocated.
Default value is 1 in 1D, else n-1 (i.e. hyperplanes).
fft_plan_workload: int, optional, defaults to 1
The number of blocks of dimension fft_granularity that a
single plan will handle at once. Default is one block.
fft_concurrent_plans: int, optional, defaults to 1
Number of concurrent plans.
Should be 1 for HOST based FFT interfaces.
Should be at least 3 for DEVICE based FFT interfaces if the device
has two async copy engines (copy, transform, copy).
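Examples
--------
Possible configuration for a 3D problem on a device based FFT interface
(illustrative sketch, `tg` denotes this transform group):

    tg.initialize(fft_granularity=2, fft_plan_workload=1, fft_concurrent_plans=3)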
"""
(domain, dim) = self.check_fields(self.forward_fields, self.backward_fields)
fft_granularity = first_not_None(fft_granularity, max(1, dim - 1))
check_instance(fft_granularity, int, minval=1, maxval=dim)
check_instance(fft_concurrent_plans, int, minval=1)
check_instance(fft_plan_workload, int, minval=1)
self._fft_granularity = fft_granularity
self._fft_concurrent_plans = fft_concurrent_plans
self._fft_plan_workload = fft_plan_workload
self._domain = domain
self._dim = dim
@_initialized
def discretize(
self,
fft_interface,
fft_interface_kwds,
enable_opencl_host_buffer_mapping,
**kwds,
):
backends = set()
grid_resolutions = set()
compute_axes = set()
compute_shapes = set()
compute_dtypes = set()
for fwd in self.forward_transforms.values():
fwd.discretize()
backends.add(fwd.backend)
grid_resolutions.add(to_tuple(fwd.dfield.mesh.grid_resolution))
compute_axes.add(fwd.output_axes)
compute_shapes.add(fwd.output_shape)
compute_dtypes.add(fwd.output_dtype)
for bwd in self.backward_transforms.values():
bwd.discretize()
backends.add(bwd.backend)
grid_resolutions.add(to_tuple(bwd.dfield.mesh.grid_resolution))
compute_axes.add(bwd.input_axes)
compute_shapes.add(bwd.input_shape)
compute_dtypes.add(bwd.input_dtype)
def format_error(data):
return "\n *" + "\n *".join(str(x) for x in data)
msg = "Fields do not live on the same backend:" + format_error(backends)
assert len(backends) == 1, msg
msg = "Fields grid size mismatch:" + format_error(grid_resolutions)
assert len(grid_resolutions) == 1, msg
assert len(compute_axes) == 1, "Fields axes mismatch:" + format_error(
compute_axes
)
assert len(compute_shapes) == 1, "Fields shape mismatch:" + format_error(
compute_shapes
)
assert len(compute_dtypes) == 1, "Fields data type mismatch." + format_error(
compute_dtypes
)
backend = next(iter(backends))
grid_resolution = next(iter(grid_resolutions))
compute_axes = next(iter(compute_axes))
compute_shape = next(iter(compute_shapes))
compute_dtype = next(iter(compute_dtypes))
if enable_opencl_host_buffer_mapping:
msg = "Trying to enable opencl device to host buffer mapping on {} target."
assert backend.kind is Backend.OPENCL, msg.format(backend.kind)
if fft_interface is None:
fft_interface_kwds = first_not_None(fft_interface_kwds, {})
fft_interface = FFTI.default_interface_from_backend(
backend,
enable_opencl_host_buffer_mapping=enable_opencl_host_buffer_mapping,
**fft_interface_kwds,
)
else:
assert not fft_interface_kwds, "FFT interface has already been built."
check_instance(fft_interface, FFTI)
fft_interface.check_backend(
backend, enable_opencl_host_buffer_mapping=enable_opencl_host_buffer_mapping
)
buffer_backend = backend
host_backend = backend.host_array_backend
backend = fft_interface.backend
discrete_wave_numbers = {}
for wn in self._wave_numbers:
(idx, freqs, nd_freqs) = self.build_wave_number(
self._domain,
grid_resolution,
backend,
wn,
compute_dtype,
compute_axes,
compute_shape,
)
self._indexed_wave_numbers[wn].indexed_object.to_backend(
backend.kind
).bind_memory_object(freqs)
self._indexed_wave_numbers[wn].index.bind_axes(compute_axes)
discrete_wave_numbers[wn] = (idx, freqs, nd_freqs)
self._discrete_wave_numbers = discrete_wave_numbers
self.buffer_backend = buffer_backend
self.host_backend = host_backend
self.backend = backend
self.FFTI = fft_interface
self.grid_resolution = grid_resolution
self.compute_axes = compute_axes
self.compute_shape = compute_shape
self.compute_dtype = compute_dtype
@classmethod
def build_wave_number(
cls,
domain,
grid_resolution,
backend,
wave_number,
compute_dtype,
compute_axes,
compute_resolution,
):
dim = domain.dim
length = domain.length
ftype, ctype = determine_fp_types(compute_dtype)
axis = wave_number.axis
transform = wave_number.transform
exponent = wave_number.exponent
idx = compute_axes.index(axis)
L = domain.length[axis]
N = grid_resolution[axis]
freqs = STU.compute_wave_numbers(transform=transform, N=N, L=L, ftype=ftype)
freqs = freqs**exponent
if STU.is_R2R(transform):
sign_offset = STU.is_cosine(transform)
freqs *= (-1) ** ((exponent + sign_offset) // 2)
assert exponent != 0, "exponent cannot be zero."
assert exponent > 0, "negative powers not implemented yet."
if is_complex(freqs.dtype) and (exponent % 2 == 0):
assert freqs.imag.sum() == 0
freqs = freqs.real.copy()
backend_freqs = backend.empty_like(freqs)
backend_freqs[...] = freqs
freqs = backend_freqs
nd_shape = [
1,
] * dim
nd_shape[idx] = freqs.size
nd_shape = tuple(nd_shape)
nd_freqs = freqs.reshape(nd_shape)
if cls.DEBUG:
print()
print("BUILD WAVENUMBER")
print(f"backend: {backend.kind}")
print(f"grid_shape: {grid_resolution}")
print(f"length: {length}")
print("-----")
print(f"ftype: {ftype}")
print(f"ctype: {ctype}")
print(f"compute shape: {compute_resolution}")
print(f"compute axes: {compute_axes}")
print("-----")
print("wave_number:")
print(f" *symbolic: {wave_number}")
print(f" *axis: {axis}")
print(f" *transform: {transform}")
print(f" *exponent: {exponent}")
print("----")
print(f"L: {L}")
print(f"N: {N}")
print(f"freqs: {freqs}")
print(f"nd_freqs: {nd_freqs}")
print("----")
return (idx, freqs, nd_freqs)
@_discretized
def get_mem_requests(self, **kwds):
memory_requests = {}
for fwd in self.forward_transforms.values():
mem_requests = fwd.get_mem_requests(**kwds)
check_instance(mem_requests, dict, keys=str, values=(int, np.integer))
for k, v in mem_requests.items():
if k in memory_requests:
memory_requests[k] = max(memory_requests[k], v)
else:
memory_requests[k] = v
for bwd in self.backward_transforms.values():
mem_requests = bwd.get_mem_requests(**kwds)
check_instance(mem_requests, dict, keys=str, values=(int, np.integer))
for k, v in mem_requests.items():
if k in memory_requests:
memory_requests[k] = max(memory_requests[k], v)
else:
memory_requests[k] = v
return memory_requests
@_discretized
def setup(self, work):
for fwd in self.forward_transforms.values():
fwd.setup(work=work)
for bwd in self.backward_transforms.values():
bwd.setup(work=work)
@_not_initialized
def require_forward_transform(
self,
field,
axes=None,
transform_tag=None,
custom_output_buffer=None,
action=None,
dump_energy=None,
plot_energy=None,
**kwds,
):
"""
Tells this SpectralTransformGroup to build a forward SpectralTransform
on given field. Only specified axes are transformed.
Boundary condition to FFT extension mapping:
Periodic: Periodic extension
Homogeneous Dirichlet: Odd extension
Homogeneous Neumann: Even extension
This leads to 5 possible transforms for each axis (periodic-periodic, even-even,
odd-odd, even-odd, odd-even).
Forward transforms used for each axis per extension pair:
*Periodic-Periodic (PER-PER): DFT (C2C, R2C for the first periodic axis)
*Dirichlet-Dirichlet (ODD-ODD): DST-I
*Dirichlet-Neumann (ODD-EVEN): DST-III
*Neumann-Dirichlet (EVEN-ODD): DCT-III
*Neumann-Neumann (EVEN-EVEN): DCT-I
This method will return the SpectralTransform object associated to field.
Parameters
----------
field: ScalarField
The source field to be transformed.
axes: array-like of integers
The axes to be transformed.
transform_tag: str
Extra tag to register the forward transform (a single scalar field can be
transformed multiple times). Default tag is 'default'.
custom_output_buffer: None or str, optional
Force this transform to output in one of the two common transform group buffers.
Default None value will force the user to allocate an output buffer.
Specifying 'B0' or 'B1' will tell the planner to output the transform
in one of the two transform group buffers (that are used during all forward
and backward transforms of the same transform group). This feature allows
FFT operators to save one buffer for the last forward transform.
Specifying 'auto' will tell the planner to choose either 'B0' or 'B1'.
action: SpectralTransformAction, optional
Defaults to SpectralTransformAction.OVERWRITE which will overwrite the
compute slices of the output buffer.
SpectralTransformAction.ACCUMULATE will sum the current content of the buffer
with the result of the forward transform.
dump_energy: IOParams, optional, defaults to None
Compute the energy for each wavenumber at given frequency after each transform.
If None is passed, no files are generated (default behaviour).
plot_energy: IOParams, optional, defaults to None
Plot field energy after each call to the forward transform to a custom file.
If None is passed, no plots are generated (default behaviour).
compute_energy_frequencies: array-like of integers, optional, defaults to None
Extra frequencies at which to compute energy.
Notes
-----
IOParams filename is formatted before being used:
{fname} is replaced with discrete field name
{ite} is replaced with simulation iteration id for plotting and '' for file dumping.
dump_energy   plot_energy   result
None          None          nothing
iop0          None          energy is computed and dumped every iop0.frequency iterations
None          iop1          energy is computed and plotted every iop1.frequency iterations
iop0          iop1          energy is computed every iop0.frequency and iop1.frequency iterations,
                            dumped every iop0.frequency
                            and plotted every iop1.frequency
About frequency:
if (frequency<0)  no dump
if (frequency==0) dump at times of interest and last iteration
if (frequency>0)  dump at times of interest, last iteration and every frequency iterations
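Examples
--------
Illustrative sketch (assuming `W` is a ScalarField registered as an input
of this operator and `iop` an IOParams instance):

    tg = op.new_transform_group()
    W_hat = tg.require_forward_transform(W, dump_energy=iop)
    # W_hat.s is the symbolic AppliedSpectralTransform that can be used
    # inside expressions passed to tg.push_expressions(...).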
"""
transform_tag = first_not_None(transform_tag, "default")
action = first_not_None(action, SpectralTransformAction.OVERWRITE)
check_instance(field, Field)
check_instance(transform_tag, str)
check_instance(action, SpectralTransformAction)
transforms = SpectralTransform(field=field, axes=axes, forward=True)
msg = 'Field {} with axes {} and transform_tag "{}" has already been registered for forward transform.'
if field.is_tensor:
planned_transforms = field.new_empty_array()
for idx, f in field.nd_iter():
assert (
f,
axes,
transform_tag,
) not in self._forward_transforms, msg.format(
f.name, axes, transform_tag
)
assert f in self._op.input_fields
assert f is transforms[idx].field
assert transforms[idx].is_forward
planned_transforms[idx] = PlannedSpectralTransform(
transform_group=self,
tag=self.tag + "_" + transform_tag + "_" + f.name,
symbolic_transform=transforms[idx],
custom_output_buffer=custom_output_buffer,
action=action,
dump_energy=dump_energy,
plot_energy=plot_energy,
**kwds,
)
self._forward_transforms[(f, axes, transform_tag)] = planned_transforms[
idx
]
else:
assert (
field,
axes,
transform_tag,
) not in self._forward_transforms, msg.format(
field.name, axes, transform_tag
)
assert field in self._op.input_fields
assert field is transforms.field
assert transforms.is_forward
planned_transforms = PlannedSpectralTransform(
transform_group=self,
tag=self.tag + "_" + transform_tag + "_" + field.name,
symbolic_transform=transforms,
custom_output_buffer=custom_output_buffer,
action=action,
dump_energy=dump_energy,
plot_energy=plot_energy,
**kwds,
)
self._forward_transforms[(field, axes, transform_tag)] = planned_transforms
return planned_transforms
@_not_initialized
def require_backward_transform(
self,
field,
axes=None,
transform_tag=None,
custom_input_buffer=None,
matching_forward_transform=None,
action=None,
dump_energy=None,
plot_energy=None,
**kwds,
):
"""
Same as require_forward_transform but for backward transforms.
This corresponds to the following backward transform mappings:
if order[axis] is 0:
*no transform -> no transform
else, if order[axis] is even:
*C2C -> C2C
*R2C -> C2R
*DCT-I -> DCT-I
*DCT-III -> DCT-II
*DST-I -> DST-I
*DST-III -> DST-II
else: (if order[axis] is odd)
*C2C -> C2C
*R2C -> C2R
*DCT-I -> DST-I
*DCT-III -> DST-II
*DST-I -> DCT-I
*DST-III -> DCT-II
For backward transforms, boundary compatibility for output_fields is thus the following:
if order[axis] is even:
Boundary conditions should be exactly the same on the axis.
else, if order[axis] is odd, boundary conditions change on this axis:
*(Periodic-Periodic) PER-PER -> PER-PER (Periodic-Periodic)
*(Neumann-Neumann) EVEN-EVEN -> ODD-ODD (Dirichlet-Dirichlet)
*(Neumann-Dirichlet) EVEN-ODD -> ODD-EVEN (Dirichlet-Neumann)
*(Dirichlet-Neumann) ODD-EVEN -> EVEN-ODD (Neumann-Dirichlet)
*(Dirichlet-Dirichlet) ODD-ODD -> EVEN-EVEN (Neumann-Neumann)
Order and boundary conditions are deduced from field.
Parameters
----------
field: ScalarField
The target field where the result of the inverse transform will be stored.
axes: array-like of integers
The axes to be transformed.
transform_tag: str
Extra tag to register the backward transform (a single scalar field can be
transformed multiple times). Default tag is 'default'.
custom_input_buffer: None or str, optional
Force this transform to take as input one of the two common transform group buffers.
Default None value will force the user to supply an input buffer.
Specifying 'B0' or 'B1' will tell the planner to take as transform input
one of the two transform group buffers (that are used during all forward
and backward transforms of the same transform group). This feature allows
FFT operators to save one buffer for the first backward transform.
Specifying 'auto' will tell the planner to use the matching
transform output buffer.
action: SpectralTransformAction, optional
Defaults to SpectralTransformAction.OVERWRITE which will overwrite the
compute slices of the given output field.
SpectralTransformAction.ACCUMULATE will sum the current content of the field
with the result of the backward transform.
dump_energy: IOParams, optional, defaults to None
Compute the energy for each wavenumber at given frequency before each transform.
If None is passed, no files are generated (default behaviour).
plot_energy: IOParams, optional, defaults to None
Plot field energy before each call to the backward transform to a custom file.
If None is passed, no plots are generated (default behaviour).
compute_energy_frequencies: array-like of integers, optional, defaults to None
Extra frequencies at which to compute energy.
Notes
-----
IOParams filename is formatted before being used:
{fname} is replaced with discrete field name
{ite} is replaced with simulation iteration id for plotting and '' for file dumping.
dump_energy   plot_energy   result
None          None          nothing
iop0          None          energy is computed and dumped every iop0.frequency iterations
None          iop1          energy is computed and plotted every iop1.frequency iterations
iop0          iop1          energy is computed every iop0.frequency and iop1.frequency iterations,
                            dumped every iop0.frequency
                            and plotted every iop1.frequency
About frequency:
if (frequency<0)  no dump
if (frequency==0) dump at times of interest and last iteration
if (frequency>0)  dump at times of interest, last iteration and every frequency iterations
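Examples
--------
Illustrative sketch (assuming `U` is an output ScalarField of this operator
and `W_hat` a forward transform previously required on the same group):

    U_hat = tg.require_backward_transform(
        U, custom_input_buffer='auto', matching_forward_transform=W_hat)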
"""
transform_tag = first_not_None(transform_tag, "default")
action = first_not_None(action, SpectralTransformAction.OVERWRITE)
check_instance(field, Field)
check_instance(transform_tag, str)
check_instance(action, SpectralTransformAction)
transforms = SpectralTransform(field=field, axes=axes, forward=False)
msg = 'Field {} with axes {} and transform_tag "{}" has already been registered for backward transform.'
if field.is_tensor:
planned_transforms = field.new_empty_array()
for idx, f in field.nd_iter():
assert (
f,
axes,
transform_tag,
) not in self._backward_transforms, msg.format(
f.name, axes, transform_tag
)
assert f in self._op.output_fields
assert not transforms[idx].is_forward
planned_transforms[idx] = PlannedSpectralTransform(
transform_group=self,
tag=self.tag + "_" + transform_tag + "_" + f.name,
symbolic_transform=transforms[idx],
custom_input_buffer=custom_input_buffer,
matching_forward_transform=matching_forward_transform,
action=action,
dump_energy=dump_energy,
plot_energy=plot_energy,
**kwds,
)
self._backward_transforms[(f, axes, transform_tag)] = (
planned_transforms[idx]
)
else:
assert (
field,
axes,
transform_tag,
) not in self._backward_transforms, msg.format(
field.name, axes, transform_tag
)
assert field in self._op.output_fields
assert not transforms.is_forward
planned_transforms = PlannedSpectralTransform(
transform_group=self,
tag=self.tag + "_" + transform_tag + "_" + field.name,
symbolic_transform=transforms,
custom_input_buffer=custom_input_buffer,
matching_forward_transform=matching_forward_transform,
action=action,
dump_energy=dump_energy,
plot_energy=plot_energy,
**kwds,
)
self._backward_transforms[(field, axes, transform_tag)] = planned_transforms
return planned_transforms
@property
def output_parameters(self):
parameters = set()
for pt in tuple(self._forward_transforms.values()) + tuple(
self._backward_transforms.values()
):
parameters.update(pt.output_parameters)
return parameters
@property
def discrete_wave_numbers(self):
assert self.discretized
discrete_wave_numbers = self._discrete_wave_numbers
if discrete_wave_numbers is None:
msg = "discrete_wave_numbers has not been set yet."
raise AttributeError(msg)
return self._discrete_wave_numbers
@_not_initialized
def push_expressions(self, *exprs):
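# Parse each symbolic expression, collect every WaveNumber symbol it
# references, and register an indexed buffer for each newly seen wave
# number. The rewritten expressions are accumulated in self._expressions.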
exprs_wave_numbers = set()
for expr in exprs:
assert isinstance(expr, sm.Basic)
(e, transforms, wn) = STU.parse_expression(expr, replace_pows=True)
self._expressions += (e,)
self._wave_numbers.update(wn)
for _wn in wn:
if _wn not in self._indexed_wave_numbers:
self._indexed_wave_numbers[_wn] = _wn.indexed_buffer()
exprs_wave_numbers.update(wn)
if self.DEBUG:
print(f"\n\nPARSING EXPRESSION {expr}")
print(f" new_expr: {e}")
print(f" transforms: {transforms}")
print(f" wave_numbers: {wn}")
return tuple(exprs_wave_numbers)
@classmethod
def check_fields(cls, forward_fields, backward_fields):
all_fields = tuple(set(forward_fields + backward_fields))
if not all_fields:
msg = "At least one field is required."
raise ValueError(msg)
domain = cls.determine_domain(*all_fields)
dim = domain.dim
return (domain, dim)
@classmethod
def determine_domain(cls, *fields):
domain = fields[0].domain
for field in fields[1:]:
if field.domain is not domain:
msg = "Domain mismatch between fields:\n{}\nvs.\n{}\n"
msg = msg.format(domain, field.domain)
raise ValueError(msg)
return domain
class PlannedSpectralTransform:
"""
A planned spectral transform is an AppliedSpectralTransform wrapper.
This object will be handled by the transform planner.
"""
DEBUG = False
def __new__(
cls,
transform_group,
tag,
symbolic_transform,
action,
custom_input_buffer=None,
custom_output_buffer=None,
matching_forward_transform=None,
dump_energy=None,
plot_energy=None,
compute_energy_frequencies=None,
**kwds,
):
return super().__new__(cls, **kwds)
def __init__(
self,
transform_group,
tag,
symbolic_transform,
action,
custom_input_buffer=None,
custom_output_buffer=None,
matching_forward_transform=None,
dump_energy=None,
plot_energy=None,
compute_energy_frequencies=None,
**kwds,
):
super().__init__(**kwds)
check_instance(transform_group, SpectralTransformGroup)
check_instance(transform_group.op, SpectralOperatorBase)
check_instance(tag, str)
check_instance(symbolic_transform, AppliedSpectralTransform)
check_instance(action, SpectralTransformAction)
check_instance(dump_energy, IOParams, allow_none=True)
check_instance(plot_energy, IOParams, allow_none=True)
assert custom_input_buffer in (None, "B0", "B1", "auto"), custom_input_buffer
assert custom_output_buffer in (None, "B0", "B1", "auto"), custom_output_buffer
field = symbolic_transform.field
is_forward = symbolic_transform.is_forward
self._transform_group = transform_group
self._tag = tag
self._symbol = symbolic_transform
self._queue = None
self._custom_input_buffer = custom_input_buffer
self._custom_output_buffer = custom_output_buffer
self._matching_forward_transform = matching_forward_transform
self._action = action
self._do_dump_energy = (dump_energy is not None) and (
dump_energy.frequency >= 0
)
self._do_plot_energy = (plot_energy is not None) and (
plot_energy.frequency >= 0
)
compute_energy_frequencies = to_set(
first_not_None(compute_energy_frequencies, set())
)
if self._do_dump_energy:
compute_energy_frequencies.add(dump_energy.frequency)
if self._do_plot_energy:
compute_energy_frequencies.add(plot_energy.frequency)
compute_energy_frequencies = set(
filter(lambda f: f >= 0, compute_energy_frequencies)
)
do_compute_energy = len(compute_energy_frequencies) > 0
self._do_compute_energy = do_compute_energy
self._compute_energy_frequencies = compute_energy_frequencies
self._plot_energy_ioparams = plot_energy
self._dump_energy_ioparams = dump_energy
if self._do_compute_energy:
ename = "E{}_{}".format("f" if is_forward else "b", field.name)
pename = "E{}_{}".format("f" if is_forward else "b", field.pretty_name)
vename = "E{}_{}".format("f" if is_forward else "b", field.var_name)
self._energy_parameter = BufferParameter(
name=ename,
pretty_name=pename,
var_name=vename,
shape=None,
dtype=None,
initial_value=None,
)
else:
self._energy_parameter = None
self._energy_dumper = None
self._energy_plotter = None
if is_forward:
msg = "Cannot specify 'custom_input_buffer' for a forward transform."
assert custom_input_buffer is None, msg
msg = "Cannot specify 'matching_forward_transform' for a forward transform."
assert matching_forward_transform is None, msg
else:
msg = "Cannot specify 'custom_output_buffer' for a backward transform."
assert self._custom_output_buffer is None, msg
if self._custom_input_buffer == "auto":
msg = "Using 'auto' as 'custom_output_buffer' of a backward transform implies "
msg += "to specify a 'matching_forward_transform' to choose the buffer from."
assert matching_forward_transform is not None, msg
assert isinstance(
matching_forward_transform, PlannedSpectralTransform
), msg
assert matching_forward_transform.is_forward, msg
else:
msg = (
"Using 'custom_output_buffer' different than 'auto' for a backward "
)
msg += "transform implies to set 'matching_forward_transform' to None."
assert matching_forward_transform is None, msg
# reorder transforms in execution order (contiguous axis first)
transforms = self.s.transforms[::-1]
if len(transforms) != field.dim:
msg = "Number of transforms does not match field dimension."
raise ValueError(msg)
if all((tr is TransformType.NONE) for tr in transforms):
msg = "All transforms are of type NONE."
raise ValueError(msg)
if is_forward:
input_dtype = field.dtype
output_dtype = STU.determine_output_dtype(field.dtype, *transforms)
else:
input_dtype = STU.determine_input_dtype(field.dtype, *transforms)
output_dtype = field.dtype
self._input_dtype = np.dtype(input_dtype)
self._output_dtype = np.dtype(output_dtype)
self._input_shape = None
self._output_shape = None
self._input_buffer = None
self._output_buffer = None
self._dfield = None
self._input_symbolic_arrays = set()
self._output_symbolic_arrays = set()
self._ready = False
@property
def output_parameters(self):
return {self._energy_parameter} - {None}
def input_symbolic_array(self, name, **kwds):
"""Create a symbolic array that will be bound to input transform array."""
assert "memory_object" not in kwds
assert "dim" not in kwds
obj = SymbolicArray(name=name, memory_object=None, dim=self.field.dim, **kwds)
self._input_symbolic_arrays.add(obj)
return obj
def output_symbolic_array(self, name, **kwds):
"""Create a symbolic array that will be bound to output transform array."""
assert "memory_object" not in kwds
assert "dim" not in kwds
obj = SymbolicArray(name=name, memory_object=None, dim=self.field.dim, **kwds)
self._output_symbolic_arrays.add(obj)
return obj
@property
def transform_group(self):
return self._transform_group
@property
def op(self):
return self._transform_group.op
@property
def tag(self):
return self._tag
@property
def name(self):
return self._tag
@property
def symbol(self):
return self._symbol
@property
def s(self):
return self._symbol
@property
def field(self):
return self._symbol.field
@property
def is_forward(self):
return self._symbol.is_forward
@property
def is_backward(self):
return not self.is_forward
@property
def transforms(self):
return self._symbol.transforms
@property
def input_dtype(self):
return self._input_dtype
@property
def output_dtype(self):
return self._output_dtype
@property
def backend(self):
assert self.discretized
backend = self._backend
if backend is None:
msg = "backend has not been set yet."
raise AttributeError(msg)
return backend
@property
def dfield(self):
assert self.discretized
if self._dfield is None:
msg = "dfield has not been set."
raise AttributeError(msg)
return self._dfield
@property
def input_shape(self):
assert self.discretized
if self._input_shape is None:
msg = "input_shape has not been set."
raise AttributeError(msg)
return self._input_shape
@property
def output_shape(self):
assert self.discretized
if self._output_shape is None:
msg = "output_shape has not been set."
raise AttributeError(msg)
return self._output_shape
@property
def input_transform_shape(self):
assert self.discretized
if self._input_transform_shape is None:
msg = "input_transform_shape has not been set."
raise AttributeError(msg)
return self._input_transform_shape
@property
def output_transform_shape(self):
assert self.discretized
if self._output_transform_shape is None:
msg = "output_transform_shape has not been set."
raise AttributeError(msg)
return self._output_transform_shape
@property
def input_axes(self):
assert self.discretized
if self._input_axes is None:
msg = "input_axes has not been set."
raise AttributeError(msg)
return self._input_axes
@property
def output_axes(self):
assert self.discretized
if self._output_axes is None:
msg = "output_axes has not been set."
raise AttributeError(msg)
return self._output_axes
@property
def input_slices(self):
assert self.discretized
buf = self._input_slices
if buf is None:
msg = "input_slices has not been set yet."
raise AttributeError(msg)
return buf
@property
def output_slices(self):
assert self.discretized
buf = self._output_slices
if buf is None:
msg = "output_slices has not been set yet."
raise AttributeError(msg)
return buf
@property
def input_buffer(self):
assert self.discretized
buf = self._input_buffer
if buf is None:
msg = "input_buffer has not been set yet."
raise AttributeError(msg)
return buf
@property
def output_buffer(self):
assert self.discretized
buf = self._output_buffer
if buf is None:
msg = "output_buffer has not been set yet."
raise AttributeError(msg)
return buf
@property
def full_input_buffer(self):
assert self.discretized
buf = self._full_input_buffer
if buf is None:
msg = "full_input_buffer has not been set yet."
raise AttributeError(msg)
return buf
@property
def full_output_buffer(self):
assert self.discretized
buf = self._full_output_buffer
if buf is None:
msg = "full_output_buffer has not been set yet."
raise AttributeError(msg)
return buf
@property
def initialized(self):
return self.op.initialized
@property
def discretized(self):
return self.op.discretized
@property
def ready(self):
return self._ready
@_initialized
def discretize(self, **kwds):
is_forward = self.is_forward
dim = self.field.dim
field_axes = TranspositionState[dim].default_axes()
if is_forward:
(dfield, transform_info, transpose_info, transform_offsets) = (
self._discretize_forward(field_axes, **kwds)
)
assert transpose_info[0][1] == field_axes
else:
(dfield, transform_info, transpose_info, transform_offsets) = (
self._discretize_backward(field_axes, **kwds)
)
assert transpose_info[-1][2] == field_axes
assert dfield.dim == len(transform_info) == len(transpose_info) == dim
assert transform_info[0][2][1] == self._input_dtype
assert transform_info[-1][3][1] == self._output_dtype
# filter out untransformed axes
tidx = tuple(
filter(lambda i: not STU.is_none(transform_info[i][1]), range(dim))
)
assert tidx, "Could not determine any transformed axe."
ntransforms = len(tidx)
transform_info = tuple(map(transform_info.__getitem__, tidx))
transpose_info = tuple(map(transpose_info.__getitem__, tidx))
assert len(transform_info) == len(transpose_info) == ntransforms
# determine input and output shapes
input_axes = transpose_info[0][1]
output_axes = transpose_info[-1][2]
if is_forward:
assert field_axes == input_axes, (field_axes, input_axes)
input_transform_shape = transpose_info[0][3]
output_transform_shape = transform_info[-1][3][0]
input_shape, input_slices, _ = self.determine_buffer_shape(
input_transform_shape, False, transform_offsets, input_axes
)
output_shape, output_slices, zfos = self.determine_buffer_shape(
output_transform_shape, True, transform_offsets, output_axes
)
# We have a situation where we should impose zeros:
# 1) output transform ghosts (when transform sizes mismatch, i.e. DXT-I variants)
zero_fill_output_slices = zfos
else:
assert field_axes == output_axes, (field_axes, output_axes)
input_transform_shape = transform_info[0][2][0]
output_transform_shape = transpose_info[-1][4]
input_shape, input_slices, _ = self.determine_buffer_shape(
input_transform_shape, True, transform_offsets, input_axes
)
output_shape, output_slices, zfos = self.determine_buffer_shape(
output_transform_shape, False, transform_offsets, output_axes
)
# We have a situation where we should impose zeros:
# 1) impose homogeneous dirichlet conditions on output
# (implicit 0's are not part of the transform output).
zero_fill_output_slices = zfos
axes = output_axes if is_forward else input_axes
ptransforms = tuple(self.transforms[i] for i in axes)
self._permuted_transforms = ptransforms
if self._do_compute_energy:
shape = output_shape if is_forward else input_shape
# view = (output_slices if is_forward else input_slices)
assert len(shape) == ntransforms
shape = tuple(
Si - 2 if sum(transform_offsets[i]) == 2 else Si
for i, Si in zip(axes, shape)
)
K2 = ()
for tr, Ni in zip(ptransforms, shape):
Ki = Ni // 2 if STU.is_C2C(tr) else Ni - 1
K2 += (Ki * Ki,)
max_wavenumber = int(round(sum(K2) ** 0.5, 0))
energy_nbytes = compute_nbytes(max_wavenumber + 1, dfield.dtype)
if dfield.backend.kind == Backend.OPENCL:
mutexes_nbytes = compute_nbytes(max_wavenumber + 1, np.int32)
else:
mutexes_nbytes = 0
self._max_wavenumber = max_wavenumber
self._energy_nbytes = energy_nbytes
self._mutexes_nbytes = mutexes_nbytes
Ep = self._energy_parameter
Ep.reallocate_buffer(shape=(max_wavenumber + 1,), dtype=dfield.dtype)
fname = "{}{}".format(dfield.name, "_in" if is_forward else "_out")
# build txt dumper
if self._do_dump_energy:
diop = self._dump_energy_ioparams
assert diop is not None
self._energy_dumper = EnergyDumper(
energy_parameter=Ep,
io_params=self._dump_energy_ioparams,
fname=fname,
)
# build plotter if required
if self._do_plot_energy:
piop = self._plot_energy_ioparams
assert piop is not None
pname = "{}.{}.{}".format(
self.op.__class__.__name__,
"forward" if is_forward else "backward",
dfield.pretty_name,
)
energy_parameters = {pname: self._energy_parameter}
self._energy_plotter = EnergyPlotter(
energy_parameters=energy_parameters,
io_params=self._plot_energy_ioparams,
fname=fname,
)
else:
self._max_wavenumber = None
self._energy_nbytes = None
self._mutexes_nbytes = None
self._dfield = dfield
self._transform_info = transform_info
self._transpose_info = transpose_info
self._ntransforms = ntransforms
self._input_axes = input_axes
self._input_shape = input_shape
self._input_slices = input_slices
self._input_transform_shape = input_transform_shape
self._output_axes = output_axes
self._output_shape = output_shape
self._output_slices = output_slices
self._output_transform_shape = output_transform_shape
self._zero_fill_output_slices = zero_fill_output_slices
self._backend = dfield.backend
if self.DEBUG:
def axis_format(info):
prefix = "\n" + " " * 4
ss = ""
for i, data in enumerate(info):
ss += prefix + f"{i}/ " + str(data)
return ss
def slc_format(slices):
if slices is None:
return "NONE"
else:
prefix = "\n" + " " * 4
ss = ""
for slc in slices:
ss += prefix + str(slc)
return ss
print(f"\n\n== SPECTRAL PLANNING INFO OF FIELD {dfield.pretty_name} ==")
print(
"transform direction: {}".format(
"FORWARD" if self.is_forward else "BACKWARD"
)
)
print(f"transforms: {self.transforms}")
print(":CARTESIAN INFO:")
print(f"cart shape: {dfield.topology.cart_shape}")
print(f"global grid resolution: {dfield.mesh.grid_resolution}")
print(f"local grid resolution: {dfield.compute_resolution}")
print(":INPUT:")
print(f"input axes: {self._input_axes}")
print(f"input dtype: {self._input_dtype}")
print(f"input transform shape: {self._input_transform_shape}")
print(f"input shape: {self._input_shape}")
print(f"input slices: {self._input_slices}")
print(":OUTPUT:")
print(f"output axes: {self._output_axes}")
print(f"output_dtype: {self._output_dtype}")
print(f"output transform shape: {self._output_transform_shape}")
print(f"output shape: {self._output_shape}")
print(f"output_slices: {self._output_slices}")
print(":TRANSFORM INFO:")
print(f"transform_info: {axis_format(transform_info)}")
print(":TRANSPOSE INFO:")
print(f"transpose_info: {axis_format(transpose_info)}")
print(":ZERO FILL:")
print(
f"zero_fill_output_slices: {slc_format(self._zero_fill_output_slices)}"
)
def get_mapped_input_buffer(self):
return self.get_mapped_full_input_buffer()[self.input_slices]
def get_mapped_output_buffer(self):
return self.get_mapped_full_output_buffer()[self.output_slices]
def get_mapped_full_input_buffer(self):
dfield = self._dfield
if (
self.is_forward
and dfield.backend.kind == Backend.OPENCL
and self.transform_group._op.enable_opencl_host_buffer_mapping
):
return self.transform_group._op.get_mapped_object(dfield)[
dfield.compute_slices
]
else:
return self.full_input_buffer
def get_mapped_full_output_buffer(self):
dfield = self._dfield
if (
self.is_backward
and dfield.backend.kind == Backend.OPENCL
and self.transform_group._op.enable_opencl_host_buffer_mapping
):
return self.transform_group._op.get_mapped_object(dfield)[
dfield.compute_slices
]
else:
return self.full_output_buffer
def determine_buffer_shape(self, transform_shape, target_is_buffer, offsets, axes):
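# For each (permuted) axis, derive the buffer shape from the transform
# shape and the per-axis (left, right) transform offsets, the slice that
# addresses the transformed data inside that buffer, and the border
# slices that must be zero-filled (implicit zeros of DXT-I-like variants
# or homogeneous Dirichlet boundaries, see the callers in discretize()).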
offsets = tuple(offsets[ai] for ai in axes)
slices = []
shape = []
zero_fill_slices = []
dim = len(axes)
for i, ((lo, ro), si) in enumerate(zip(offsets, transform_shape)):
if (lo ^ ro) and target_is_buffer:
Si = si
slc = slice(0, si)
else:
Si = si + lo + ro
slc = slice(lo, Si - ro)
if lo > 0:
zfill = [slice(None, None, None)] * dim
zfill[i] = slice(0, lo)
zfill = tuple(zfill)
zero_fill_slices.append(zfill)
if ro > 0:
zfill = [slice(None, None, None)] * dim
zfill[i] = slice(Si - ro, Si)
zfill = tuple(zfill)
zero_fill_slices.append(zfill)
shape.append(Si)
slices.append(slc)
return tuple(shape), tuple(slices), tuple(zero_fill_slices)
def configure_input_buffer(self, buf):
input_dtype, input_shape = self.input_dtype, self.input_shape
buf_nbytes = compute_nbytes(buf.shape, buf.dtype)
input_nbytes = compute_nbytes(input_shape, input_dtype)
assert buf_nbytes >= input_nbytes, (buf_nbytes, input_nbytes)
if (buf.shape != input_shape) or (buf.dtype != input_dtype):
buf = (
buf.view(dtype=np.int8)[:input_nbytes]
.view(dtype=input_dtype)
.reshape(input_shape)
)
if isinstance(buf, Array):
buf = buf.handle
input_buffer = buf[self.input_slices]
assert input_buffer.shape == self.input_transform_shape
self._full_input_buffer = buf
self._input_buffer = input_buffer
for symbol in self._input_symbolic_arrays:
symbol.to_backend(self.backend.kind).bind_memory_object(buf)
return input_buffer
def configure_output_buffer(self, buf):
output_dtype, output_shape = self.output_dtype, self.output_shape
buf_nbytes = compute_nbytes(buf.shape, buf.dtype)
output_nbytes = compute_nbytes(output_shape, output_dtype)
assert buf_nbytes >= output_nbytes, (buf_nbytes, output_nbytes)
if (buf.shape != output_shape) or (buf.dtype != output_dtype):
buf = (
buf.view(dtype=np.int8)[:output_nbytes]
.view(dtype=output_dtype)
.reshape(output_shape)
)
if isinstance(buf, Array):
buf = buf.handle
output_buffer = buf[self.output_slices]
assert output_buffer.shape == self.output_transform_shape
self._full_output_buffer = buf
self._output_buffer = output_buffer
for symbol in self._output_symbolic_arrays:
symbol.to_backend(self.backend.kind).bind_memory_object(buf)
return output_buffer
def _discretize_forward(self, field_axes, **kwds):
dfield = self.op.input_discrete_fields[self.field]
grid_resolution = dfield.mesh.grid_resolution
local_resolution = dfield.compute_resolution
input_dtype = dfield.dtype
dim = dfield.dim
forward_transforms = self.transforms[::-1]
backward_transforms = STU.get_inverse_transforms(*forward_transforms)
(resolution, transform_offsets) = STU.get_transform_resolution(
local_resolution, *forward_transforms
)
local_transform_info = self._determine_transform_info(
forward_transforms, resolution, input_dtype
)
local_transpose_info = self._determine_transpose_info(
field_axes, local_transform_info
)
local_transform_info = self._permute_transform_info(
local_transform_info, local_transpose_info
)
transform_info = local_transform_info
transpose_info = local_transpose_info
return (dfield, transform_info, transpose_info, transform_offsets)
def _discretize_backward(self, field_axes, **kwds):
forward_transforms = self.transforms[::-1]
backward_transforms = STU.get_inverse_transforms(*forward_transforms)
def reverse_transform_info(transform_info):
transform_info = list(transform_info)
for i, d in enumerate(transform_info):
d = list(d)
d[1] = forward_transforms[i]
d2, d3 = d[2:4]
d[2:4] = d3, d2
transform_info[i] = tuple(d)
transform_info = tuple(transform_info)
return transform_info[::-1]
def reverse_transpose_info(transpose_info):
transpose_info = list(transpose_info)
for i, d in enumerate(transpose_info):
if d[0] is not None:
d = list(d)
d1, d2, d3, d4 = d[1:5]
d[1:5] = d2, d1, d4, d3
d[0] = tuple(d[1].index(ai) for ai in d[2])
d = tuple(d)
else:
# no permutation
assert d[1] == d[2]
assert d[3] == d[4]
transpose_info[i] = d
return transpose_info[::-1]
dfield = self.op.output_discrete_fields[self.field]
grid_resolution = dfield.mesh.grid_resolution
local_resolution = dfield.compute_resolution
output_dtype = dfield.dtype
dim = dfield.dim
(resolution, transform_offsets) = STU.get_transform_resolution(
local_resolution, *backward_transforms
)
local_backward_transform_info = self._determine_transform_info(
backward_transforms, resolution, output_dtype
)
local_backward_transpose_info = self._determine_transpose_info(
field_axes, local_backward_transform_info
)
local_backward_transform_info = self._permute_transform_info(
local_backward_transform_info, local_backward_transpose_info
)
local_forward_transform_info = reverse_transform_info(
local_backward_transform_info
)
local_forward_transpose_info = reverse_transpose_info(
local_backward_transpose_info
)
transform_info = local_forward_transform_info
transpose_info = local_forward_transpose_info
return (dfield, transform_info, transpose_info, transform_offsets)
@classmethod
def _determine_transform_info(cls, transforms, src_shape, src_dtype):
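# Build, for each 1D transform (contiguous axis first), a tuple
#   (axis, transform, (src_shape, src_dtype, src_view),
#                     (dst_shape, dst_dtype, dst_view))
# describing how the array shape, data type and view evolve through the
# successive forward transforms (an R2C transform halves the transformed
# axis to n//2 + 1 and promotes the data type to complex).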
transform_info = []
dim = len(transforms)
dst_shape, dst_dtype = src_shape, src_dtype
dst_view = [slice(0, si) for si in src_shape]
for i, tr in enumerate(transforms):
axis = i
src_shape = dst_shape
src_dtype = dst_dtype
src_view = dst_view
if STU.is_none(tr):
pass
elif STU.is_backward(tr):
msg = "{} is not a forward transform."
msg = msg.format(tr)
raise ValueError(msg)
elif STU.is_R2R(tr):
msg = f"Expected a floating point data type but got {src_dtype}."
assert is_fp(src_dtype), msg
# data type and shape do not change
elif STU.is_R2C(tr):
msg = f"Expected a floating point data type but got {src_dtype}."
assert is_fp(src_dtype), msg
dst_shape = list(src_shape)
dst_shape[dim - axis - 1] = dst_shape[dim - axis - 1] // 2 + 1
dst_shape = tuple(dst_shape)
dst_dtype = float_to_complex_dtype(src_dtype)
elif STU.is_C2C(tr):
msg = f"Expected a complex data type but got {src_dtype}."
assert is_complex(src_dtype), msg
# data type and shape do not change
else:
msg = f"Unknown transform type {tr}."
raise ValueError(msg)
(lo, ro) = STU.get_transform_offsets(tr)
src_view = src_view[:]
src_view[dim - axis - 1] = slice(lo, src_shape[dim - axis - 1] - ro)
dst_view = src_view[:]
dst_view[dim - axis - 1] = slice(lo, dst_shape[dim - axis - 1] - ro)
src_dtype = np.dtype(src_dtype)
dst_dtype = np.dtype(dst_dtype)
data = (
axis,
tr,
(src_shape, src_dtype, tuple(src_view)),
(dst_shape, dst_dtype, tuple(dst_view)),
)
transform_info.append(data)
transform_info = tuple(transform_info)
return transform_info
@classmethod
def _determine_transpose_info(cls, src_axes, transform_info):
transpose_info = []
dim = len(src_axes)
for (
axis,
tr,
(src_shape, src_dtype, src_view),
(dst_shape, dst_dtype, dst_view),
) in transform_info:
dst_axis = dim - 1 - axis
if (not STU.is_none(tr)) and (dst_axis != src_axes[-1]):
idx = src_axes.index(dst_axis)
dst_axes = list(src_axes)
dst_axes[idx] = src_axes[-1]
dst_axes[-1] = dst_axis
dst_axes = tuple(dst_axes)
permutation = tuple(src_axes.index(ai) for ai in dst_axes)
else:
dst_axes = src_axes
permutation = None
dst_shape = tuple(src_shape[ai] for ai in dst_axes)
src_shape = tuple(src_shape[ai] for ai in src_axes)
data = (permutation, src_axes, dst_axes, src_shape, dst_shape)
transpose_info.append(data)
src_axes = dst_axes
transpose_info = tuple(transpose_info)
return transpose_info
@classmethod
def _permute_transform_info(cls, transform_info, transpose_info):
assert len(transform_info) == len(transpose_info)
transform_info = list(transform_info)
for i, (transpose, transform) in enumerate(zip(transpose_info, transform_info)):
(_, _, dst_axes, _, transpose_out_shape) = transpose
(_1, _2, (src_shape, _3, src_view), (dst_shape, _4, dst_view)) = transform
permuted_src_shape = tuple(src_shape[ai] for ai in dst_axes)
permuted_src_view = tuple(src_view[ai] for ai in dst_axes)
permuted_dst_shape = tuple(dst_shape[ai] for ai in dst_axes)
permuted_dst_view = tuple(dst_view[ai] for ai in dst_axes)
assert permuted_src_shape == transpose_out_shape
transform = (
_1,
_2,
(permuted_src_shape, _3, permuted_src_view),
(permuted_dst_shape, _4, permuted_dst_view),
)
transform_info[i] = transform
transform_info = tuple(transform_info)
return transform_info
@_discretized
def get_mem_requests(self, **kwds):
# first we need to find out src and dst buffers for transforms (B0 and B1)
nbytes = 0
for (
_,
_,
(src_shape, src_dtype, src_view),
(dst_shape, dst_dtype, dst_view),
) in self._transform_info:
nbytes = max(nbytes, compute_nbytes(src_shape, src_dtype))
nbytes = max(nbytes, compute_nbytes(dst_shape, dst_dtype))
nbytes = max(nbytes, compute_nbytes(self.input_shape, self.input_dtype))
nbytes = max(nbytes, compute_nbytes(self.output_shape, self.output_dtype))
# Then we need to find out the size of an additional tmp buffer.
# We can only do this by creating temporary plans prior to setup,
# using temporary buffers.
tmp_nbytes = 0
tg = self.transform_group
src = tg.FFTI.backend.empty(
shape=(nbytes,), dtype=np.uint8, min_alignment=tg.op.min_fft_alignment
)
dst = tg.FFTI.backend.empty(
shape=(nbytes,), dtype=np.uint8, min_alignment=tg.op.min_fft_alignment
)
queue = tg.FFTI.new_queue(tg=tg, name="tmp_queue")
for (
_,
tr,
(src_shape, src_dtype, src_view),
(dst_shape, dst_dtype, dst_view),
) in self._transform_info:
src_nbytes = compute_nbytes(src_shape, src_dtype)
dst_nbytes = compute_nbytes(dst_shape, dst_dtype)
b0 = src[:src_nbytes].view(dtype=src_dtype).reshape(src_shape)
b1 = dst[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape)
fft_plan = tg.FFTI.get_transform(tr)(
a=b0.handle, out=b1.handle, axis=self.field.dim - 1, verbose=False
)
fft_plan.setup(queue=queue)
tmp_nbytes = max(tmp_nbytes, fft_plan.required_buffer_size)
del src
del dst
if tmp_nbytes > nbytes:
msg = "Planner claims to need more than buffer bytes as temporary buffer:"
msg += f"\n *Buffer bytes: {bytes2str(nbytes)}"
msg += f"\n *Tmp bytes: {bytes2str(tmp_nbytes)}"
warnings.warn(msg, HysopFFTWarning)
backend = self.transform_group.backend
mem_tag = self.transform_group.mem_tag
field_tag = self.dfield.name
kind = backend.kind
B0_tag = f"{mem_tag}_{kind}_B0"
B1_tag = f"{mem_tag}_{kind}_B1"
TMP_tag = f"{mem_tag}_{kind}_TMP"
ENERGY_tag = f"{mem_tag}_{kind}_ENERGY"
MUTEXES_tag = f"{mem_tag}_{kind}_MUTEXES"
self.B0_tag, self.B1_tag, self.TMP_tag, self.ENERGY_tag, self.MUTEXES_tag = (
B0_tag,
B1_tag,
TMP_tag,
ENERGY_tag,
MUTEXES_tag,
)
requests = {B0_tag: nbytes, B1_tag: nbytes, TMP_tag: tmp_nbytes}
if self._do_compute_energy:
if self._energy_nbytes > 0:
requests[ENERGY_tag] = self._energy_nbytes
if self._mutexes_nbytes > 0:
requests[MUTEXES_tag] = self._mutexes_nbytes
return requests
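# Example of the returned mapping (hypothetical tags and byte counts):
#   {"<mem_tag>_<kind>_B0": 8388608,   # ping-pong transform buffer 0
#    "<mem_tag>_<kind>_B1": 8388608,   # ping-pong transform buffer 1
#    "<mem_tag>_<kind>_TMP": 0}        # extra scratch space claimed by the planner
# ENERGY and MUTEXES entries are only added when energy computation is enabled.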
@_discretized
def setup(self, work):
SETUP_DEBUG = False
assert not self.ready
dim = self.field.dim
op = self.op
tg = self.transform_group
FFTI = tg.FFTI
is_forward = self.is_forward
is_backward = self.is_backward
ntransforms = self._ntransforms
transform_info = self._transform_info
transpose_info = self._transpose_info
B0_tag, B1_tag = self.B0_tag, self.B1_tag
TMP_tag = self.TMP_tag
ENERGY_tag = self.ENERGY_tag
MUTEXES_tag = self.MUTEXES_tag
# get temporary buffers
(B0,) = work.get_buffer(op, B0_tag, handle=True)
(B1,) = work.get_buffer(op, B1_tag, handle=True)
assert is_byte_aligned(B0)
assert is_byte_aligned(B1)
try:
(TMP,) = work.get_buffer(op, TMP_tag, handle=True)
except ValueError:
TMP = None
if (self._energy_nbytes is not None) and (self._energy_nbytes > 0):
(ENERGY,) = work.get_buffer(op, ENERGY_tag, handle=True)
energy_buffer = ENERGY[: self._energy_nbytes].view(dtype=self.dfield.dtype)
assert energy_buffer.size == self._max_wavenumber + 1
else:
ENERGY = None
energy_buffer = None
if (self._mutexes_nbytes is not None) and (self._mutexes_nbytes > 0):
(MUTEXES,) = work.get_buffer(op, MUTEXES_tag, handle=True)
mutexes_buffer = MUTEXES[: self._mutexes_nbytes].view(dtype=np.int32)
assert mutexes_buffer.size == self._max_wavenumber + 1
else:
MUTEXES = None
mutexes_buffer = None
# Bind transformed field buffer to input or output.
# This only happens if the user did not bind another buffer prior to the setup.
dfield = self.dfield
if is_forward and (self._input_buffer is None):
self.configure_input_buffer(dfield.sbuffer[dfield.compute_slices])
elif is_backward and (self._output_buffer is None):
self.configure_output_buffer(dfield.sbuffer[dfield.compute_slices])
# bind group buffer to input or output if required.
custom_input_buffer = self._custom_input_buffer
custom_output_buffer = self._custom_output_buffer
if is_forward and custom_output_buffer:
if custom_output_buffer == "auto":
# will be determined and set later
pass
elif custom_output_buffer == "B0":
self.configure_output_buffer(B0)
elif custom_output_buffer == "B1":
self.configure_output_buffer(B1)
else:
msg = f"Unknown custom output buffer {custom_output_buffer}."
raise NotImplementedError(msg)
if is_backward and custom_input_buffer:
if custom_input_buffer == "auto":
assert self._matching_forward_transform.ready
custom_input_buffer = (
self._matching_forward_transform._custom_output_buffer
)
assert custom_input_buffer in ("B0", "B1")
if custom_input_buffer == "B0":
self.configure_input_buffer(B0)
elif custom_input_buffer == "B1":
self.configure_input_buffer(B1)
else:
msg = f"Unknown custom input buffer {custom_input_buffer}."
raise NotImplementedError(msg)
# Define the source and destination (ping-pong) buffers, as well as tmp buffers.
src_buffer, dst_buffer = B0, B1
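# B0 and B1 act as ping-pong buffers: each planned operation reads from src_buffer and
# writes to dst_buffer, then the two names are swapped so that the next operation
# consumes what the previous one produced.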
def nameof(buf):
assert (buf is B0) or (buf is B1)
if buf is B0:
return "B0"
else:
return "B1"
def check_size(buf, nbytes, name):
if buf.nbytes < nbytes:
msg = "Insufficient buffer size for buffer {} (shape={}, dtype={}).".format(
name, buf.shape, buf.dtype
)
msg += f"\nExpected at least {nbytes} bytes but got {buf.nbytes}."
try:
bname = nameof(buf)
msg += f"\nThis buffer has been identified as {bname}."
except AssertionError:
pass
raise RuntimeError(msg)
# build spectral transform execution queue
qname = "fft_planner_{}_{}".format(
self.field.name, "forward" if is_forward else "backward"
)
queue = FFTI.new_queue(tg=self, name=qname)
if SETUP_DEBUG:
def print_op(description, category):
prefix = " |> "
print(f"{prefix}{description: <40}[{category}]")
msg = """
SPECTRAL TRANSFORM SETUP
op: {}
dim: {}
ntransforms: {}
group_tag: {}
is_forward: {}
is_backward: {}""".format(
op.pretty_tag, dim, ntransforms, self.tag, is_forward, is_backward
)
print(msg)
fft_plans = ()
for i in range(ntransforms):
transpose = transpose_info[i]
transform = transform_info[i]
(permutation, _, _, input_shape, output_shape) = transpose
(
_,
tr,
(src_shape, src_dtype, src_view),
(dst_shape, dst_dtype, dst_view),
) = transform
assert not STU.is_none(tr), "Got a NONE transform type."
is_first = i == 0
is_last = i == ntransforms - 1
should_forward_permute = is_forward and (permutation is not None)
should_backward_permute = is_backward and (permutation is not None)
if SETUP_DEBUG:
msg = f" TRANSFORM INDEX {i}:"
if permutation is not None:
msg += """
Transpose Info:
permutation: {}
input_shape: {}
output_shape: {}
forward_permute: {}
backward_permute: {}""".format(
permutation,
input_shape,
output_shape,
should_forward_permute,
should_backward_permute,
)
msg += """
Custom buffers:
custom_input: {}
custom_output: {}
Transform Info:
SRC: shape {} and type {}, view {}
DST: shape {} and type {}, view {}
Planned Operations:""".format(
custom_input_buffer,
custom_output_buffer,
src_shape,
src_dtype,
src_view,
dst_shape,
dst_dtype,
dst_view,
)
print(msg)
src_nbytes = compute_nbytes(src_shape, src_dtype)
dst_nbytes = compute_nbytes(dst_shape, dst_dtype)
# build forward permutation if required
# (forward transforms transpose before actual transforms)
if should_forward_permute:
input_nbytes = compute_nbytes(input_shape, src_dtype)
output_nbytes = compute_nbytes(output_shape, src_dtype)
assert (
output_shape == src_shape
), "Transpose to Transform shape mismatch."
assert (
input_nbytes == output_nbytes
), "Transpose input and output size mismatch."
assert (
src_buffer.nbytes >= input_nbytes
), "Insufficient buffer size for src buf."
assert (
dst_buffer.nbytes >= output_nbytes
), "Insufficient buffer size for dst buf."
if is_first:
assert (
self.input_buffer.shape == input_shape
), "input_buffer shape mismatch."
assert (
self.input_buffer.dtype == src_dtype
), "input_buffer dtype mismatch."
b0 = self.get_mapped_input_buffer
else:
b0 = (
src_buffer[:input_nbytes]
.view(dtype=src_dtype)
.reshape(input_shape)
)
b1 = (
dst_buffer[:output_nbytes]
.view(dtype=src_dtype)
.reshape(output_shape)
)
queue += FFTI.plan_transpose(tg=tg, src=b0, dst=b1, axes=permutation)
if SETUP_DEBUG:
sfrom = "input_buffer" if is_first else nameof(src_buffer)
sto = nameof(dst_buffer)
print_op(
f"PlanTranspose(src={sfrom}, dst={sto}, permutation={permutation})",
"forward permute",
)
src_buffer, dst_buffer = dst_buffer, src_buffer
elif is_first:
assert (
self.input_buffer.shape == src_shape
), "input buffer shape mismatch."
assert (
self.input_buffer.dtype == src_dtype
), "input buffer dtype mismatch."
assert (
src_buffer.nbytes >= src_nbytes
), "Insufficient buffer size for src buf."
if (custom_input_buffer is not None) and (
nameof(src_buffer) == custom_input_buffer
):
src_buffer, dst_buffer = dst_buffer, src_buffer
b0 = src_buffer[:src_nbytes].view(dtype=src_dtype).reshape(src_shape)
queue += FFTI.plan_copy(tg=tg, src=self.get_mapped_input_buffer, dst=b0)
if SETUP_DEBUG:
sfrom = "input_buffer"
sto = nameof(src_buffer)
print_op(f"PlanCopy(src={sfrom}, dst={sto})", "pre-transform copy")
# build batched 1D transform in contiguous axis
check_size(src_buffer, src_nbytes, "src")
check_size(dst_buffer, dst_nbytes, "dst")
b0 = src_buffer[:src_nbytes].view(dtype=src_dtype).reshape(src_shape)
b1 = dst_buffer[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape)
fft_plan = FFTI.get_transform(tr)(a=b0, out=b1, axis=dim - 1)
fft_plan.setup(queue=queue)
fft_plans += (fft_plan,)
queue += fft_plan
if SETUP_DEBUG:
sfrom = nameof(src_buffer)
sto = nameof(dst_buffer)
print_op(f"PlanTransform(src={sfrom}, dst={sto})", tr)
src_buffer, dst_buffer = dst_buffer, src_buffer
# build backward permutation if required
# (backward transforms transpose after actual transforms)
if should_backward_permute:
input_nbytes = compute_nbytes(input_shape, dst_dtype)
output_nbytes = compute_nbytes(output_shape, dst_dtype)
assert (
input_shape == dst_shape
), "Transform to Transpose shape mismatch."
assert (
input_nbytes == output_nbytes
), "Transpose input and output size mismatch."
assert (
src_buffer.nbytes >= input_nbytes
), "Insufficient buffer size for src buf."
assert (
dst_buffer.nbytes >= output_nbytes
), "Insufficient buffer size for dst buf."
b0 = (
src_buffer[:input_nbytes].view(dtype=dst_dtype).reshape(input_shape)
)
if is_last and (self._action is SpectralTransformAction.OVERWRITE):
assert (
self.output_buffer.shape == output_shape
), "output buffer shape mismatch."
assert (
self.output_buffer.dtype == dst_dtype
), "output buffer dtype mismatch."
b1 = self.get_mapped_output_buffer
else:
b1 = (
dst_buffer[:output_nbytes]
.view(dtype=dst_dtype)
.reshape(output_shape)
)
queue += FFTI.plan_transpose(tg=tg, src=b0, dst=b1, axes=permutation)
if SETUP_DEBUG:
sfrom = nameof(src_buffer)
sto = "output_buffer" if is_last else nameof(dst_buffer)
print_op(
f"PlanTranspose(src={sfrom}, dst={sto})", "backward permute"
)
src_buffer, dst_buffer = dst_buffer, src_buffer
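# When the last backward transform does not overwrite its target, the permuted result
# is accumulated into the mapped output buffer instead (only ACCUMULATE is supported).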
if is_last and (self._action is not SpectralTransformAction.OVERWRITE):
if self._action is SpectralTransformAction.ACCUMULATE:
assert (
self.output_buffer.shape == output_shape
), "output buffer shape mismatch."
assert (
self.output_buffer.dtype == dst_dtype
), "output buffer dtype mismatch."
queue += FFTI.plan_accumulate(
tg=tg, src=b1, dst=self.get_mapped_output_buffer
)
if SETUP_DEBUG:
sfrom = nameof(dst_buffer)
sto = "output_buffer"
print_op(
f"PlanAccumulate(src={sfrom}, dst={sto})",
"post-transform accumulate",
)
else:
msg = f"Unsupported action {self._action}."
raise NotImplementedError(msg)
elif is_last:
if custom_output_buffer is not None:
if custom_output_buffer not in ("B0", "B1", "auto"):
msg = f"Unknown custom output buffer {custom_output_buffer}."
raise NotImplementedError(msg)
elif custom_output_buffer == "auto":
custom_output_buffer = nameof(dst_buffer)
self._custom_output_buffer = custom_output_buffer
if custom_output_buffer == "B0":
self.configure_output_buffer(B0)
elif custom_output_buffer == "B1":
self.configure_output_buffer(B1)
else:
raise RuntimeError(f"Unexpected custom output buffer {custom_output_buffer}.")
elif nameof(src_buffer) == custom_output_buffer:
# This is a special case where we need to copy back and forth
# (because of offsets)
b0 = (
src_buffer[:dst_nbytes]
.view(dtype=dst_dtype)
.reshape(dst_shape)
)
b1 = (
dst_buffer[:dst_nbytes]
.view(dtype=dst_dtype)
.reshape(dst_shape)
)
queue += FFTI.plan_copy(tg=tg, src=b0, dst=b1)
if SETUP_DEBUG:
sfrom = nameof(src_buffer)
sto = nameof(dst_buffer)
print_op(
f"PlanCopy(src={sfrom}, dst={sto})",
"post-transform copy",
)
src_buffer, dst_buffer = dst_buffer, src_buffer
assert (
self.output_buffer.shape == dst_shape
), "output buffer shape mismatch."
assert (
self.output_buffer.dtype == dst_dtype
), "output buffer dtype mismatch."
assert (
src_buffer.nbytes >= dst_nbytes
), "Insufficient buffer size for src buf."
b0 = src_buffer[:dst_nbytes].view(dtype=dst_dtype).reshape(dst_shape)
if self._action is SpectralTransformAction.OVERWRITE:
pname = "PlanCopy"
pdes = "post-transform-copy"
queue += FFTI.plan_copy(
tg=tg, src=b0, dst=self.get_mapped_output_buffer
)
elif self._action is SpectralTransformAction.ACCUMULATE:
pname = "PlanAccumulate"
pdes = "post-transform-accumulate"
queue += FFTI.plan_accumulate(
tg=tg, src=b0, dst=self.get_mapped_output_buffer
)
else:
msg = f"Unsupported action {self._action}."
raise NotImplementedError(msg)
if SETUP_DEBUG:
sfrom = nameof(src_buffer)
sto = (
"output_buffer"
if (custom_output_buffer is None)
else custom_output_buffer
)
print_op(f"{pname}(src={sfrom}, dst={sto})", pdes)
if self._zero_fill_output_slices:
buf = self.get_mapped_full_output_buffer
slcs = self._zero_fill_output_slices
queue += FFTI.plan_fill_zeros(tg=tg, a=buf, slices=slcs)
if SETUP_DEBUG:
print_op("PlanFillZeros(dst=output_buffer)", "post-transform-callback")
# allocate fft plans
FFTI.allocate_plans(op, fft_plans, tmp_buffer=TMP)
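# All FFT plans of this transform share the same scratch buffer (TMP), sized in
# get_mem_requests() as the maximum of the individual plan requirements.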
# build kernels to compute energy if required
if self._do_compute_energy:
field_buffer = self.input_buffer if self.is_forward else self.output_buffer
spectral_buffer = (
self.output_buffer if self.is_forward else self.input_buffer
)
compute_energy_queue = FFTI.new_queue(tg=self, name="dump_energy")
compute_energy_queue += FFTI.plan_fill_zeros(
tg=tg, a=energy_buffer, slices=(Ellipsis,)
)
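# One 32-bit lock per wavenumber bin: the mutexes guard concurrent accumulation into
# the energy spectrum and are reset (unlocked) at the start of each energy computation.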
if mutexes_buffer is not None:
unlock_mutexes = FFTI.plan_fill_zeros(
tg=tg, a=mutexes_buffer, slices=(Ellipsis,)
)
compute_energy_queue += unlock_mutexes
compute_energy_queue().wait()  # execute once now so the mutexes start unlocked before the first energy computation
compute_energy_queue += FFTI.plan_compute_energy(
tg=tg,
fshape=field_buffer.shape,
src=spectral_buffer,
dst=energy_buffer,
transforms=self._permuted_transforms,
mutexes=mutexes_buffer,
)
compute_energy_queue += FFTI.plan_copy(
tg=tg, src=energy_buffer, dst=self._energy_parameter._value
)
else:
compute_energy_queue = None
self._frequency_ioparams = tuple(
self.io_params.clone(frequency=f, with_last=True)
for f in self._compute_energy_frequencies
)
self._queue = queue
self._compute_energy_queue = compute_energy_queue
self._ready = True
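# Execution: __call__ chains optional pre-transform diagnostics, the planned transform
# queue, and optional post-transform diagnostics, threading events between the stages.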
def __call__(self, **kwds):
assert self._ready
assert self._queue is not None
evt = self._pre_transform_actions(**kwds)
evt = self._queue.execute(wait_for=evt)
evt = self._post_transform_actions(wait_for=evt, **kwds)
return evt
def _pre_transform_actions(self, simulation=None, wait_for=None, **kwds):
evt = wait_for
if simulation is False:
return evt
if self.is_backward and self._do_compute_energy:
evt = self.compute_energy(simulation=simulation, wait_for=evt)
if self._do_plot_energy:
evt = self.plot_energy(simulation=simulation, wait_for=evt)
return evt
def _post_transform_actions(self, simulation=None, wait_for=None, **kwds):
evt = wait_for
if simulation is False:
return evt
if self.is_forward and self._do_compute_energy:
evt = self.compute_energy(simulation=simulation, wait_for=evt)
if self._do_plot_energy:
evt = self.plot_energy(simulation=simulation, wait_for=evt)
return evt
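# Energy diagnostics run on the spectral side of the data: before the queue for
# backward transforms (the input is spectral) and after the queue for forward
# transforms (the output is spectral).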
def compute_energy(self, simulation, wait_for):
msg = f"No simulation was passed in {type(self)}.__call__()."
assert simulation is not None, msg
evt = wait_for
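# Energy is only computed at steps where at least one of the configured io frequencies
# triggers a dump.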
should_compute_energy = any(
iop.should_dump(simulation=simulation) for iop in self._frequency_ioparams
)
if should_compute_energy:
evt = self._compute_energy_queue(wait_for=evt)
if self._do_dump_energy:
self._energy_dumper.update(simulation=simulation, wait_for=evt)
return evt
def plot_energy(self, simulation, wait_for):
msg = f"No simulation was passed in {type(self)}.__call__()."
assert simulation is not None, msg
evt = wait_for
self._energy_plotter.update(simulation=simulation, wait_for=evt)
return evt